from transformers import AutoModel, AutoTokenizer
import streamlit as st
import glob
import chromadb
from text2vec import SentenceModel

# Configure the Streamlit page: wide layout, robot icon, Chinese demo title.
_PAGE_CONFIG = {
    "page_title": "ChatGLM2-6b 演示",
    "page_icon": ":robot:",
    "layout": "wide",
}
st.set_page_config(**_PAGE_CONFIG)

@st.cache_resource
def get_vectordb():
    """Build the in-memory Chroma collection from ``texts/*.txt``.

    Loads the Chinese sentence-embedding model, reads every ``.txt``
    file under ``texts/``, embeds the file contents, and stores them
    in a Chroma collection. Cached by Streamlit so it runs once per
    process, not on every rerun.

    Returns:
        tuple: ``(collection, model)`` — the populated Chroma
        collection and the ``SentenceModel`` used for embedding.
    """
    model = SentenceModel('shibing624/text2vec-base-chinese')
    client = chromadb.Client()

    texts = []
    for filename in glob.glob("texts/*.txt"):
        with open(filename, encoding='utf-8') as f:
            texts.append(f.read())

    collection = client.get_or_create_collection("testname")
    # Guard the empty-corpus case: collection.add() raises on empty
    # id/embedding lists, which would crash the app when texts/ is empty.
    if texts:
        embeddings = model.encode(texts).tolist()
        ids = [f'id{i + 1}' for i, _ in enumerate(texts)]
        collection.add(ids=ids, embeddings=embeddings, documents=texts)
    return collection, model

@st.cache_resource
def get_model(model_path: str = r"D:\LLM\ChatGLM2-6B-main\model"):
    """Load the ChatGLM2 tokenizer and model onto the GPU.

    The previously hard-coded checkpoint directory is now a parameter
    (with the original path as default) so the app can point at a
    different checkpoint without editing two string literals.

    Args:
        model_path: Directory containing the ChatGLM2 checkpoint.

    Returns:
        tuple: ``(tokenizer, model)`` with the model in fp16 eval mode
        on CUDA.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    # ChatGLM2 ships custom modeling code, hence trust_remote_code=True.
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()
    model = model.eval()
    return tokenizer, model


def query_related(text: str, model: "SentenceModel", coll):
    """Return the stored document most similar to *text*.

    Args:
        text: Query string to embed.
        model: Sentence-embedding model exposing ``encode``.
        coll: Chroma collection to search.

    Returns:
        str: The best-matching document, or an empty string when the
        collection contains no documents (the original code raised
        IndexError in that case).
    """
    embedding = model.encode(text).tolist()
    # Chroma returns one list of documents per query embedding.
    docs = coll.query(query_embeddings=embedding, n_results=1)['documents']
    if not docs or not docs[0]:
        return ""
    return docs[0][0]


# Load the chat model/tokenizer and the vector store; both are wrapped
# in @st.cache_resource, so reruns reuse the already-loaded objects.
tokenizer, model = get_model()
collection, t2v = get_vectordb()

st.title("ChatGLM2-6B-int4 外挂知识库")

# Generation hyper-parameters, tunable from the sidebar on every rerun.
max_length = st.sidebar.slider(
    'max_length', 0, 32768, 32768, step=1
)
top_p = st.sidebar.slider(
    'top_p', 0.0, 1.0, 0.8, step=0.01
)
temperature = st.sidebar.slider(
    'temperature', 0.0, 1.0, 0.8, step=0.01
)

# Conversation history ([(query, response), ...]) survives Streamlit
# reruns via session_state.
if 'history' not in st.session_state:
    st.session_state.history = []

# Transformer KV cache carried across turns so stream_chat can resume
# generation incrementally instead of re-encoding the whole history.
if 'past_key_values' not in st.session_state:
    st.session_state.past_key_values = None

# Re-render the full conversation so far on every rerun.
for i, (query, response) in enumerate(st.session_state.history):
    with st.chat_message(name="user", avatar="user"):
        st.markdown(query)
    with st.chat_message(name="assistant", avatar="assistant"):
        st.markdown(response)
# Empty placeholders for the in-flight user turn and the streamed reply.
with st.chat_message(name="user", avatar="user"):
    input_placeholder = st.empty()
with st.chat_message(name="assistant", avatar="assistant"):
    message_placeholder = st.empty()

prompt_text = st.text_area(label="用户命令输入",
                           height=100,
                           placeholder="请在这儿输入您的命令")

button = st.button("发送", key="predict")

if button:
    # Echo the raw user input into the chat pane immediately.
    input_placeholder.markdown(prompt_text)
    # RAG step: fetch the single most similar knowledge-base document
    # and wrap the user's question in a prompt that quotes it.
    related_text = query_related(prompt_text, t2v, collection)
    prompt_text = f"'''\n{related_text}\n''' \n请上文提取信息并回答：“{prompt_text}”"
    history, past_key_values = st.session_state.history, st.session_state.past_key_values
    # Stream the reply: each iteration yields the response-so-far plus
    # updated history and KV cache; the placeholder is overwritten so the
    # answer appears to type itself out.
    # NOTE(review): stream_chat appends the *augmented* prompt (with the
    # retrieved text) to history, so replays show the RAG prompt rather
    # than the user's original input — confirm this is intended.
    for response, history, past_key_values in model.stream_chat(tokenizer, prompt_text, history,
                                                                past_key_values=past_key_values,
                                                                max_length=max_length, top_p=top_p,
                                                                temperature=temperature,
                                                                return_past_key_values=True):
        message_placeholder.markdown(response)

    # Persist the updated conversation and KV cache for the next rerun.
    st.session_state.history = history
    st.session_state.past_key_values = past_key_values
